import json

from datasets import load_dataset

from .geoquery_dataset import get_preprocessed_geoquery

def _is_list_of_dicts(obj):
    """Return True if *obj* is a list whose every element is a dict."""
    return isinstance(obj, list) and all(isinstance(item, dict) for item in obj)


def get_dataset(dataset="gsm8k", load_from_local=True, data_dir="data"):
    """Load a dataset and return its train/test inputs and targets.

    Args:
        dataset: Dataset name. ``"gsm8k"`` is supported both remotely and
            locally. Names starting with ``"geo880"`` are supported only
            locally and must have the form ``"geo880-<split>"``: the second
            ``-``-separated field is passed to ``get_preprocessed_geoquery``
            as ``split_method``.
        load_from_local: If truthy, read from local files; otherwise fetch
            GSM8K through ``datasets.load_dataset("openai/gsm8k", "main")``.
        data_dir: Directory containing ``gsm8k/gsm8k_train.json`` and
            ``gsm8k/gsm8k_test.json`` for the local GSM8K branch. A relative
            path is resolved against the current working directory. The
            geo880 branch does not use it.

    Returns:
        A 4-tuple ``(train_inputs, train_targets, test_inputs, test_targets)``.
        For GSM8K these are the ``"question"`` / ``"answer"`` fields; for
        geo880 they are the ``"source"`` / ``"target"`` fields of what
        ``get_preprocessed_geoquery`` returns.

    Raises:
        ValueError: If the dataset name is not supported for the chosen
            source, if a geo880 name has no ``-<split>`` suffix, or if a
            local GSM8K JSON file is not a list of dictionaries.
    """
    if not load_from_local:
        if dataset == "gsm8k":
            ds_train = load_dataset("openai/gsm8k", "main", split="train")
            ds_test = load_dataset("openai/gsm8k", "main", split="test")
            return ds_train["question"], ds_train["answer"], ds_test["question"], ds_test["answer"]
        raise ValueError("Invalid dataset.")

    if dataset == "gsm8k":
        with open(f"{data_dir}/gsm8k/gsm8k_train.json", "r", encoding="utf-8") as file:
            ds_train = json.load(file)
        with open(f"{data_dir}/gsm8k/gsm8k_test.json", "r", encoding="utf-8") as file:
            ds_test = json.load(file)
        # Both files must be a list of records so the field lookups below are valid.
        if not (_is_list_of_dicts(ds_train) and _is_list_of_dicts(ds_test)):
            raise ValueError("JSON file content is not a list of dictionaries")
        return (
            [record["question"] for record in ds_train],
            [record["answer"] for record in ds_train],
            [record["question"] for record in ds_test],
            [record["answer"] for record in ds_test],
        )

    if dataset.startswith("geo880"):
        parts = dataset.split("-")
        # Without a "-<split>" suffix, indexing parts[1] would raise a bare IndexError.
        if len(parts) < 2:
            raise ValueError(
                f"geo880 dataset name must include a split method, e.g. 'geo880-<split>'; got {dataset!r}"
            )
        split_method = parts[1]
        ds_train, ds_test = get_preprocessed_geoquery(split_method=split_method)
        return ds_train["source"], ds_train["target"], ds_test["source"], ds_test["target"]

    # Previously this fell through and returned None; match the remote branch instead.
    raise ValueError("Invalid dataset.")
